"""
Command-line argument parsing.
"""

import argparse
from functools import partial

import tensorflow as tf

from .minibatchprox import MinibatchProx

def _str2bool(value):
    """
    Convert a command-line string to a bool, for use as an argparse ``type=``.

    ``type=bool`` must not be used for this: argparse passes the raw string,
    and ``bool('False')`` / ``bool('0')`` are both True. Values that are
    already bools are returned unchanged.

    Raises:
      argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
        spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def argument_parser():
    """
    Get an argument parser for a training script.

    Boolean options that take a value (``--pretrained``, ``--transductive``)
    accept true/false, yes/no, 1/0, and similar spellings (case-insensitive).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--pretrained', help='evaluate a pre-trained model', default=False, type=_str2bool)
    parser.add_argument('--lam_reg', help='weight of regularization', default=0.1, type=float)
    parser.add_argument('--seed', help='random seed', default=0, type=int)
    parser.add_argument('--checkpoint', help='checkpoint directory', default='model_checkpoint')
    parser.add_argument('--classes', help='number of classes per inner task', default=5, type=int)
    parser.add_argument('--shots', help='number of examples per class', default=5, type=int)
    parser.add_argument('--train-shots', help='shots in a training batch', default=0, type=int)
    parser.add_argument('--inner-batch', help='inner batch size', default=5, type=int)
    parser.add_argument('--inner-iters', help='inner iterations', default=20, type=int)
    parser.add_argument('--replacement', help='sample with replacement', action='store_true')
    parser.add_argument('--learning-rate', help='Adam step size', default=1e-3, type=float)
    parser.add_argument('--meta-step', help='meta-training step size', default=0.1, type=float)
    parser.add_argument('--meta-step-final', help='meta-training step size by the end',
                        default=0.1, type=float)
    parser.add_argument('--meta-batch', help='meta-training batch size', default=1, type=int)
    parser.add_argument('--meta-iters', help='meta-training iterations', default=400000, type=int)
    parser.add_argument('--eval-batch', help='eval inner batch size', default=5, type=int)
    parser.add_argument('--eval-iters', help='eval inner iterations', default=50, type=int)
    parser.add_argument('--eval-samples', help='evaluation samples', default=20, type=int)
    parser.add_argument('--eval-interval', help='train steps per eval', default=10, type=int)
    parser.add_argument('--weight-decay', help='weight decay rate', default=1, type=float)
    parser.add_argument('--transductive', help='evaluate all samples at once', default=True, type=_str2bool)
    parser.add_argument('--sgd', help='use vanilla SGD instead of Adam', action='store_true')
    parser.add_argument('--x_dim', type=str, default="84,84,3", metavar='XDIM', help='input image dims')
    parser.add_argument('--ratio', type=float, default=1.0, metavar='RATIO',
                        help="ratio of labeled data (for semi-supervised setting)")
    parser.add_argument('--pkl', type=int, default=1, metavar='PKL', help="1 for use pkl dataset, 0 for original images")
    # DATASET must be 'tieredimagenet' or 'miniimagenet'.
    # For tieredImageNet the previous defaults were:
    #   --DATASET tieredimagenet --DATA_DIR ../../tiered-imagenet/
    parser.add_argument('--DATASET', type=str, default="miniimagenet", help='datasetname')
    parser.add_argument('--DATA_DIR', type=str, default="/export/home/dataset/bakminiimagnet", help='dataset path')
    return parser






def model_kwargs(parsed_args):
    """
    Translate parsed command-line arguments into model-constructor kwargs.

    The result always carries ``learning_rate``. When ``--sgd`` was given,
    it also carries ``optimizer`` set to TensorFlow's plain gradient-descent
    optimizer class (the class itself, not an instance).
    """
    kwargs = dict(learning_rate=parsed_args.learning_rate)
    if not parsed_args.sgd:
        return kwargs
    kwargs['optimizer'] = tf.train.GradientDescentOptimizer
    return kwargs


def data_kwargs(parsed_args):
    """
    Build the dataset-related kwargs from the parsed command-line arguments.

    Returns a dict with:
      x_dim:    the image-dimension string exactly as given (e.g. "84,84,3").
      ratio:    ratio of labeled data.
      seed:     random seed.
      DATA_DIR: dataset root path.

    The consumer of this dict is not visible in this module (presumably the
    dataset loader -- confirm at the call site).

    Note: ``--pkl`` is parsed by argument_parser() but is not forwarded here.
    """
    return {
        'x_dim': parsed_args.x_dim,
        'ratio': parsed_args.ratio,
        'seed': parsed_args.seed,
        'DATA_DIR': parsed_args.DATA_DIR,
    }


def train_kwargs(parsed_args):
    """
    Map parsed command-line arguments onto the keyword arguments of train().

    A falsy ``train_shots`` (such as the parser default of 0) is forwarded as
    None. The MinibatchProx class itself is passed under ``MinibatchProx_m``.
    """
    args = parsed_args
    kwargs = dict(
        num_classes=args.classes,
        num_shots=args.shots,
        train_shots=(args.train_shots or None),
        inner_batch_size=args.inner_batch,
        inner_iters=args.inner_iters,
        replacement=args.replacement,
        meta_step_size=args.meta_step,
        meta_step_size_final=args.meta_step_final,
        meta_batch_size=args.meta_batch,
        meta_iters=args.meta_iters,
        eval_inner_batch_size=args.eval_batch,
        eval_inner_iters=args.eval_iters,
        eval_interval=args.eval_interval,
        weight_decay_rate=args.weight_decay,
        transductive=args.transductive,
        lam_reg=args.lam_reg,
        MinibatchProx_m=MinibatchProx,
        dataset_name=args.DATASET,
    )
    return kwargs

def evaluate_kwargs(parsed_args):
    """
    Map parsed command-line arguments onto the keyword arguments of evaluate().

    The MinibatchProx class itself is passed under ``MinibatchProx_m``.
    """
    args = parsed_args
    kwargs = dict(
        num_classes=args.classes,
        num_shots=args.shots,
        eval_inner_batch_size=args.eval_batch,
        eval_inner_iters=args.eval_iters,
        replacement=args.replacement,
        weight_decay_rate=args.weight_decay,
        num_samples=args.eval_samples,
        transductive=args.transductive,
        lam_reg=args.lam_reg,
        MinibatchProx_m=MinibatchProx,
        dataset_name=args.DATASET,
    )
    return kwargs

def _args_reptile(parsed_args):
    """
    Return the meta-learner class to use.

    ``parsed_args`` is accepted but currently ignored; the result is always
    the MinibatchProx class.
    """
    del parsed_args  # unused
    return MinibatchProx
